import tensorflow as tf
from cifar10 import cifar10_input
import numpy as np

'''
通过一个带有全局平均池化的卷积神经网络对CIFAR数据集进行分类

本例使用全局平均池化层来代替传统的全连接层
使用了 3 个卷积层的同卷积操作，滤波器为 5x5，
每个卷积层后面都会跟一个步长为 2x2 的池化层，滤波器为 2x2
2层的卷积加池化后是输出为10个通道的卷积层
然后对这10个feature map进行全局平均池化，得到 10 个特征，
再对这10个特征进行softmax计算，其结果来代表最终分类
'''
# Load the CIFAR-10 input pipeline (TF1 queue-based readers).
batch_size = 128
data_dir = './cifar/cifar-10-batches-bin'

# eval_data=False -> training split; eval_data=True -> test split.
# Each call returns an image-batch tensor and a label-batch tensor that are
# dequeued by the queue runners started later; images are 24x24x3 per the
# placeholder defined below.
images_train, labels_train = cifar10_input.inputs(eval_data=False, data_dir=data_dir, batch_size=batch_size)
images_test, labels_test = cifar10_input.inputs(eval_data=True, data_dir=data_dir, batch_size=batch_size)

'''
定义网络结构
对于权重 weight 统一使用函数 truncated_normal 来生成标准差为 0.1 的随机数来初始化
对于偏置值 biases 统一初始化为 0.1
卷积操作的函数中，统一进行同卷积操作，即步长为 1，padding='SAME'

池化层有两个函数：
1）一个放在卷积后面，取最大值的方法，步长为 2，padding='SAME'， 即将原尺寸的长和宽各除以2
2）一个放在最后一层，取平均值的方法，步长为最终生成的特征尺寸 6X6(24X24经过两次池化变成了6X6)，filter也为 6X6

倒数第二层是没有最大池化的卷积层，因为共有10类，所以卷积输出的是10个通道，并使其全局平均池化为10个节点
'''


def weight_init(shape):
    """Create a trainable weight variable of the given shape.

    Values are drawn from a truncated normal distribution with stddev 0.1,
    the initialization convention used throughout this file.
    """
    return tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1))


def biases_init(shape):
    """Create a trainable bias variable of the given shape, filled with 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))


def conv2d(input_, filter_):
    """2-D convolution with stride 1 and 'SAME' padding.

    'SAME' zero-pads the borders so the output spatial size equals the
    input size (a "same" convolution); 'VALID' would apply no padding.
    """
    unit_stride = [1, 1, 1, 1]
    return tf.nn.conv2d(input_, filter_, strides=unit_stride, padding='SAME')


def max_pool_2x2(input_):
    """2x2 max pooling with stride 2: halves each spatial dimension."""
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(input_, ksize=window, strides=window, padding='SAME')


def avg_pool_6x6(input_, k=6):
    """Average pooling over a k x k window with stride k (default 6x6).

    With a 6x6 feature map (24x24 halved by two 2x2 max pools) the default
    reduces each channel to a single value, i.e. global average pooling.
    The `k` parameter generalizes the hard-coded window so the same helper
    works for other final feature-map sizes; existing callers are unaffected.

    Args:
        input_: 4-D tensor (batch, height, width, channels).
        k: pooling window and stride size; defaults to 6.
    """
    window = [1, k, k, 1]
    return tf.nn.avg_pool(input_, ksize=window, strides=window, padding='SAME')


# ---- Network graph ----
# Placeholders for a batch of images and one-hot labels.
x = tf.placeholder(tf.float32, [None, 24, 24, 3])   # CIFAR batches are 24x24 RGB images
y = tf.placeholder(tf.float32, [None, 10])   # one-hot labels over the 10 classes

w_conv1 = weight_init([5, 5, 3, 64])  # layer-1 kernel: 5x5 window, 3 input channels, 64 filters
'''
out_height = in_height / strides_height  = 24 / 1 = 24
out_width = in_width / strides_width  = 24 /1 =24
padding_height = max((out_height-1)*strides_height + filter_height - in_height, 0) = max((24-1)*1+5-24, 0) = 4
padding_width = max((out_width-1)*strides_width + filter_width - in_width, 0) = max((24-1)*1+5-24, 0) = 4
padding_top = padding_height / 2 = 2
padding_bottom = padding_height - padding_top = 2
padding_left = padding_width / 2 = 2
padding_right = padding_width - padding_left = 2
'''
b_conv1 = biases_init([64])

# No-op reshape for correctly shaped input; pins the static shape.
x_image = tf.reshape(x, [-1, 24, 24, 3])

# Layer 1: conv + ReLU, then 2x2 max pool.
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)   # shape (n, 12, 12, 64)

w_conv2 = weight_init([5, 5, 64, 64])
b_conv2 = biases_init([64])

# Layer 2: conv + ReLU, then 2x2 max pool.
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)     # shape (n, 6, 6, 64)

# Layer 3: conv down to 10 channels (one per class), no max pool;
# global average pooling then reduces each channel to one value.
w_conv3 = weight_init([5, 5, 64, 10])
b_conv3 = biases_init([10])
h_conv3 = tf.nn.relu(conv2d(h_pool2, w_conv3) + b_conv3)
nt_hpool3 = avg_pool_6x6(h_conv3)

nt_hpool3_flat = tf.reshape(nt_hpool3, [-1, 10])
y_conv = tf.nn.softmax(nt_hpool3_flat)

# NOTE(review): -sum(y * log(softmax)) is numerically unstable when y_conv
# contains zeros; tf.nn.softmax_cross_entropy_with_logits on nt_hpool3_flat
# would be safer, but swapping it would alter the trained graph — left as-is.
cross_entropy = -tf.reduce_sum(y*tf.log(y_conv))

train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# Accuracy: fraction of the batch where the predicted class matches the label.
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))

# ---- Training ----
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Start the input-pipeline queue threads so sess.run on the batch tensors
# can dequeue data; images_train batches are (128, 24, 24, 3).
tf.train.start_queue_runners(sess=sess)

for i in range(2000):
    # Pull one training batch; label_batch is a vector of int class ids,
    # e.g. [0 5 6 0 1 2 5 ... 2 0 7 3 7].
    image_batch, label_batch = sess.run([images_train, labels_train])
    label_b = np.eye(10, dtype=float)[label_batch]   # convert labels to one-hot rows

    train_step.run(feed_dict={x: image_batch, y: label_b}, session=sess)

    if i % 200 == 0:
        # Accuracy on the current training batch (not a held-out set).
        train_accuracy = accuracy.eval(feed_dict={x: image_batch, y: label_b}, session=sess)
        print('step %d, train_accuracy %g' % (i, train_accuracy))


# 评估结果
image_batch, label_batch = sess.run([images_test, labels_test])
label_b = np.eye(10, dtype=float)[label_batch]
test_accuracy = accuracy.eval(feed_dict={x: image_batch, y: label_b}, session=sess)
print('Finished!!!, test_accuracy %g' % (i, train_accuracy))
